unify moe implementation for llama4 and deepseek_v3 #1534
Conversation
danielvegamyhre left a comment:
lgtm, left a couple of minor comments/questions.
```diff
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+# from torch.distributed._functional_collectives import all_to_all_single_autograd
```
nit: remove commented code
this is intentional -- we should restore this implementation after the bug is fixed. I reorganized the code a bit to make it clearer.
```python
@staticmethod
def forward(ctx, x, out_splits, in_splits, group):
    if isinstance(out_splits, torch.Tensor):
        out_splits = out_splits.tolist()
```
won't `tolist()` cause a d2h sync? Is this okay / intentional in this case?
It will. This is a temporary fix, but currently in EP there are multiple places with d2h sync. I'm working on another implementation to kill them.
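For context, here is a minimal, hedged sketch (not the PR's actual code) of a manual autograd wrapper around `dist.all_to_all_single`, in the style of the snippet quoted above. `dist.all_to_all_single` takes Python int lists for the split sizes, so a CUDA splits tensor has to go through `.tolist()`, which is where the device-to-host sync comes from:

```python
import torch
import torch.distributed as dist


class AllToAllSingle(torch.autograd.Function):
    """Hypothetical manual autograd wrapper; not the PR's implementation."""

    @staticmethod
    def forward(ctx, x, out_splits, in_splits, group):
        # dist.all_to_all_single needs host-side Python ints for the split
        # sizes, so a CUDA splits tensor forces a d2h sync here.
        if isinstance(out_splits, torch.Tensor):
            out_splits = out_splits.tolist()
        if isinstance(in_splits, torch.Tensor):
            in_splits = in_splits.tolist()
        ctx.group, ctx.out_splits, ctx.in_splits = group, out_splits, in_splits
        out = x.new_empty((sum(out_splits), *x.shape[1:]))
        dist.all_to_all_single(out, x.contiguous(), out_splits, in_splits, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Reverse the split sizes to route gradients back to their source ranks.
        grad_in = grad_out.new_empty((sum(ctx.in_splits), *grad_out.shape[1:]))
        dist.all_to_all_single(
            grad_in, grad_out.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
        )
        return grad_in, None, None, None
```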
ruisizhang123 left a comment:
LGTM, thank you for the refactor.
Force-pushed from 85dc2ad to 16ad9f5.
torchtitan/models/moe.py (outdated):

```python
@dataclass
class MoEArgs:
    moe_enabled: bool = True
```
Why do we need to have `moe_enabled` in `MoEArgs`?
I didn't see this parameter set to false anywhere.
makes sense, removed
wwwjn left a comment:
LGTM! Nice refactor
(Sorry, I am leaving for vacation and too lazy to open a PR.)
We fixed this by adding a […] in […], where […]. We also have a better version of the bias update that only needs one reduce; you can check the code here. One trick here does not affect the bias update, but we need to be aware that, once we have activation checkpointing, it will be called more than once, so the actual stats value of […]
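For readers following the thread, a hedged sketch of a single-reduce, sign-based bias update of the kind described above; `expert_bias`, `tokens_per_expert`, and `update_rate` are illustrative names, not torchtitan's actual API:

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def update_expert_bias(
    expert_bias: torch.Tensor,        # float buffer, one entry per expert
    tokens_per_expert: torch.Tensor,  # running per-expert token counts
    update_rate: float = 1e-3,
    group=None,
):
    # Single reduce: sum the per-expert token counts across all ranks.
    dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.SUM, group=group)
    load = tokens_per_expert.float()
    # With activation checkpointing, the forward that accumulates these
    # counts runs more than once, so the raw values are inflated; the
    # sign-based update below only depends on relative load, not scale.
    mean_load = load.mean()
    # Aux-loss-free balancing: push bias down for overloaded experts,
    # up for underloaded ones.
    expert_bias -= update_rate * torch.sign(load - mean_load)
    tokens_per_expert.zero_()  # reset counters for the next iteration
```

Because a uniform inflation of the counts (e.g. from the forward running twice under activation checkpointing) does not change the sign of `load - mean_load`, the update itself is unaffected, which matches the caveat raised in the comment above.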
Issue pointed out in #1534 (comment): pytorch/pytorch#160285. Solution given by @rakkit in #1534 (comment).
Nice @rakkit, we found the same issue with the EP grads being off by a factor. I was finding that `set_reduce_scatter_divide_factor` errored when using an mp policy, though. Surprised you didn't hit that? I think I saw you're on […]
lol @garrett361 thanks for the info. I did not see the issue on either Torch 2.6 or 2.7.1. To clarify, I only tested the default mp setting (`mixed_precision_param=bf16` and `mixed_precision_reduce=fp32`).
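For reference, a minimal sketch of the default mixed-precision setup being discussed, using FSDP2's `fully_shard`; the toy module and the divide factor value are placeholders, and this is not the torchtitan configuration:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Requires torch.distributed to be initialized (e.g. launched via torchrun).
# bf16 parameters for compute, fp32 gradient reduce-scatter -- the
# "default mp set" mentioned above.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

model = nn.Linear(1024, 1024)
fully_shard(model, mp_policy=mp_policy)

# Override the divide factor applied to gradients during reduce-scatter
# (e.g. to account for expert-parallel ranks); this is the call reported
# above to error when combined with an mp policy on some torch versions.
model.set_reduce_scatter_divide_factor(8.0)
```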

Given the complexity of MoE and EP modules, this PR:

1. creates `torchtitan/models/moe.py` as the central MoE implementation (similar to why we have `torchtitan/models/attention.py`)
2. creates `torchtitan/distributed/expert_parallel.py` as the central EP implementation
3. renames `torchtitan/distributed/pipeline.py` -> `torchtitan/distributed/pipeline_parallel.py` to be consistent with EP
4. applies the temporary fix by @rakkit from pytorch#1467 until the memory leak issue with AC + PT-D `all_to_all_single_autograd` is fixed (cc @soulitzer)